#!/usr/bin/env python3
import ctypes
import os
import datetime
import sys

class cBenchCounters(ctypes.Structure):
    '''
    This has to match the returned struct in libbag.c
    '''
    _fields_ = [ ("additions", ctypes.c_int),
                 ("failed_removals", ctypes.c_int),
                 ("successful_removals", ctypes.c_int) ]

class cBenchResult(ctypes.Structure):
    '''
    This has to match the returned struct in libbag.c
    '''
    _fields_ = [ ("time", ctypes.c_float),
                 ("add_time", ctypes.c_float),
                 ("try_remove_any_time", ctypes.c_float),
                 ("counters", cBenchCounters) ]
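
# For reference, a minimal sketch of the C structs these classes are assumed to
# mirror; the authoritative definitions live in libbag.c, and the typedef names
# here are hypothetical:
#
#   typedef struct {
#       int additions;
#       int failed_removals;
#       int successful_removals;
#   } bench_counters_t;
#
#   typedef struct {
#       float time;
#       float add_time;
#       float try_remove_any_time;
#       bench_counters_t counters;
#   } bench_result_t;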

class Benchmark:
    '''
    Class representing a benchmark. It assumes every benchmark sweeps over some
    parameter range (xrange) using the same fixed set of inputs for each point.
    It provides two ways of summarizing the given number of repetitions:
    - represent everything in a boxplot, or
    - average over the results.
    '''
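
    # Minimal usage sketch (hypothetical values; bench_fn is a ctypes function
    # whose restype has been set to cBenchResult):
    #   b = Benchmark(bench_fn, (1000, lambda x: x // 2), repetitions_per_point=3,
    #                 xrange=[2, 4, 8], basedir=".", name="example", now="2024-01-01")
    #   b.run()
    #   b.write_avg_data()
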
    def __init__(self, bench_function, parameters,
                 repetitions_per_point, xrange, basedir, name, now):
        self.bench_function = bench_function
        self.parameters = parameters
        self.repetitions_per_point = repetitions_per_point
        self.xrange = xrange
        self.basedir = basedir
        self.name = name

        self.result_times = {}
        self.result_add_times = {}
        self.result_try_remove_any_times = {}
        self.result_throughput = {}
        self.result_add_throughput = {}
        self.result_try_remove_any_throughput = {}
        self.now = now

    def __str__(self):
        return f'Benchmark(name={self.name}, xrange={self.xrange}, repetitions={self.repetitions_per_point})'

    def run(self):
        '''
        Runs the benchmark with the given parameters. Collects
        repetitions_per_point data points for each x value and stores them in
        the result dictionaries for later processing.
        '''
        print(f"Starting Benchmark run at {self.now}")

        # for each thread count on the x-axis
        for x in self.xrange:
            times = []
            add_times = []
            try_remove_any_times = []
            throughput = []
            add_throughput = []
            try_remove_any_throughput = []
            for r in range(self.repetitions_per_point):
                #print(f'run {r} out of {self.repetitions_per_point}')
                # evaluate any lambda function parameters (functions of thread count x)
                evaluated = [x] + [item(x) if callable(item) else item for item in self.parameters]
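                # e.g. with x = 8 and parameters (1000, lambda x: x // 2),
                # evaluated == [8, 1000, 4] (illustrative values)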
                # unpack self.parameters tuple (a, b, ...) into parameter list a, b, ...
                result = self.bench_function(*evaluated)
                times.append(result.time * 1000)  # scale time (assumed seconds) to milliseconds
                throughput.append(result.counters.successful_removals / result.time)  # successful removals per time unit
                if result.add_time > 0:
                    add_times.append(result.add_time * 1000)
                    add_throughput.append(result.counters.additions / result.add_time)
                if result.try_remove_any_time > 0:
                    try_remove_any_times.append(result.try_remove_any_time * 1000)
                    try_remove_any_throughput.append(result.counters.successful_removals / result.try_remove_any_time)
            self.result_times[x] = times
            self.result_throughput[x] = throughput
            if len(add_times) > 0:
                self.result_add_times[x] = add_times
            if len(try_remove_any_times) > 0:
                self.result_try_remove_any_times[x] = try_remove_any_times
            if len(add_throughput) > 0:
                self.result_add_throughput[x] = add_throughput
            if len(try_remove_any_throughput) > 0:
                self.result_try_remove_any_throughput[x] = try_remove_any_throughput

    def write_avg_data(self, extended=False):
        '''
        Writes the average of every measured point into a dataset in the data
        folder, timestamped with the start of the run.
        '''
        if self.now is None:
            raise Exception("Benchmark was not run or timestamp was not set. Run before writing data.")
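
        # The directories and files below form a layout like (run stamp and name
        # are illustrative):
        #   <basedir>/data/<now>/avg/time/<name>.data
        #   <basedir>/data/<now>/avg/throughput/<name>.data
        #   plus optional <name>.add.data and <name>.tra.data variants.
        # Each .data file is whitespace-separated: a header line "x datapoint"
        # followed by one "<x> <average>" row per measured point.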

        try:
            os.makedirs(f"{self.basedir}/data/{self.now}/avg/time")
            os.makedirs(f"{self.basedir}/data/{self.now}/avg/throughput")
        except FileExistsError:
            pass
        with open(f"{self.basedir}/data/{self.now}/avg/time/{self.name}.data", "w")\
                as datafile:
            datafile.write("x datapoint\n")
            for x, box in self.result_times.items():
                datafile.write(f"{x} {sum(box)/len(box)}\n")
        with open(f"{self.basedir}/data/{self.now}/avg/throughput/{self.name}.data", "w")\
                as datafile:
            datafile.write("x datapoint\n")
            for x, box in self.result_throughput.items():
                datafile.write(f"{x} {sum(box)/len(box)}\n")
        if len(self.result_add_times) > 0:
            with open(f"{self.basedir}/data/{self.now}/avg/time/{self.name}.add.data", "w")\
                    as datafile:
                datafile.write("x datapoint\n")
                for x, box in self.result_add_times.items():
                    datafile.write(f"{x} {sum(box)/len(box)}\n")
            with open(f"{self.basedir}/data/{self.now}/avg/throughput/{self.name}.add.data", "w")\
                    as datafile:
                datafile.write("x datapoint\n")
                for x, box in self.result_add_throughput.items():
                    datafile.write(f"{x} {sum(box)/len(box)}\n")
        if len(self.result_try_remove_any_times) > 0:
            with open(f"{self.basedir}/data/{self.now}/avg/time/{self.name}.tra.data", "w")\
                    as datafile:
                datafile.write("x datapoint\n")
                for x, box in self.result_try_remove_any_times.items():
                    datafile.write(f"{x} {sum(box)/len(box)}\n")
            with open(f"{self.basedir}/data/{self.now}/avg/throughput/{self.name}.tra.data", "w")\
                    as datafile:
                datafile.write("x datapoint\n")
                for x, box in self.result_try_remove_any_throughput.items():
                    datafile.write(f"{x} {sum(box)/len(box)}\n")

def benchmark(parameters):
    '''
    Requires the benchmark binary to also be built as a shared library
    (libbag.so) located next to this script.
    '''
    basedir = os.path.dirname(os.path.abspath(__file__))
    binary = ctypes.CDLL(f"{basedir}/libbag.so")
    # Set the result type for each benchmark function
    binary.bench_simple_prod_cons.restype = cBenchResult
    binary.bench_cbag_prod_cons.restype = cBenchResult
    binary.bench_simple_add_try_remove_any.restype = cBenchResult
    binary.bench_cbag_add_try_remove_any.restype = cBenchResult
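    # Without an explicit restype, ctypes would assume each function returns a C
    # int, so the cBenchResult struct returned by value would be misinterpreted.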

    # try to get the parameterized workloads, otherwise use the default
    arg_workloads = parameters.get('workload', '100,1000,10000')
    # try to get the parameterized thread counts, otherwise use the default
    arg_threads = parameters.get('threads', '2,4,8,16,32,64')
    # try to get the parameterized repetition count, otherwise use the default
    arg_repetitions = parameters.get('repetitions', '2')
    # try to get the parameterized file prefix (default: small)
    arg_prefix = parameters.get('prefix', 'small')
    # flag: run the simple benchmarks
    arg_simple = parameters.get('simple', 'true')
    # flag: run the cbag benchmarks
    arg_cbag = parameters.get('cbag', 'true')
    # flag: run producer/consumer with half of the threads as consumers
    arg_prod_cons = parameters.get('prod_cons', 'true')
    # flag: run with 1 producer and the remaining threads as consumers
    arg_1_prod = parameters.get('1_prod', 'true')
    # flag: run with 1 consumer and the remaining threads as producers
    arg_1_cons = parameters.get('1_cons', 'true')
    # flag: run the add/try_remove_any benchmark
    arg_add_tra = parameters.get('add_try_remove_any', 'true')

    # The number of elements to benchmark
    workloads = [int(n) for n in arg_workloads.split(',')]
    # The number of threads. This is the x-axis of the benchmark, i.e., the
    # parameter that is swept over.
    num_threads = [int(t) for t in arg_threads.split(',')]
    # How often to repeat one benchmark
    repetitions = int(arg_repetitions)
    # use simple
    simple = arg_simple == 'true'
    # use cbag
    cbag = arg_cbag == 'true'
    # use prod_cons
    use_prod_cons = arg_prod_cons == 'true'
    # use 1_prod
    use_1_prod = arg_1_prod == 'true'
    # use 1_cons
    use_1_cons = arg_1_cons == 'true'
    # use add_try_remove_any
    use_add_tra = arg_add_tra == 'true'
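    # With the defaults above this yields workloads == [100, 1000, 10000],
    # num_threads == [2, 4, 8, 16, 32, 64], repetitions == 2, and every feature
    # flag set to True.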

    # The timestamp to mark every benchmark with
    #now = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    now = datetime.datetime.now().strftime("%Y-%m-%d")

    benchmarks = []
    for n in workloads:
        # Parameters for the benchmark are passed as a tuple, e.g. (1000,). To pass
        # just one parameter we cannot write (1000), because that would not parse as
        # a tuple; Python treats the trailing comma as a one-element tuple. Callable
        # entries are evaluated per thread count x in Benchmark.run, and each point
        # is repeated `repetitions` times (see the illustration below).
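        # Illustration of the consumer-count lambdas below, assuming x = 8 threads:
        #   prod_cons -> x // 2 = 4 consumers, 1_cons -> 1 consumer,
        #   1_prod -> x - 1 = 7 consumers (i.e. a single producer).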

        if simple:
            if use_prod_cons:
                # benchmark prod/cons where x // 2 of the x threads are consumers
                benchmarks.append(Benchmark(binary.bench_simple_prod_cons, (n, lambda x: x//2), repetitions, num_threads, basedir, f"{arg_prefix}bench_simple_prod_cons_{n}", now))
            if use_1_cons:
                # benchmark prod/cons where 1 of the x threads is a consumer
                benchmarks.append(Benchmark(binary.bench_simple_prod_cons, (n, lambda x: 1), repetitions, num_threads, basedir, f"{arg_prefix}bench_simple_1_cons_{n}", now))
            if use_1_prod:
                # benchmark prod/cons where x - 1 of the x threads are consumers (1 producer)
                benchmarks.append(Benchmark(binary.bench_simple_prod_cons, (n, lambda x: x - 1), repetitions, num_threads, basedir, f"{arg_prefix}bench_simple_1_prod_{n}", now))
            if use_add_tra:
                # benchmark add and try_remove_any times
                benchmarks.append(Benchmark(binary.bench_simple_add_try_remove_any, (n,), repetitions, num_threads, basedir, f"{arg_prefix}bench_simple_add_tra_{n}", now))

        if cbag:
            if use_prod_cons:
                # benchmark prod/cons where x // 2 of the x threads are consumers
                benchmarks.append(Benchmark(binary.bench_cbag_prod_cons, (n, lambda x: x//2), repetitions, num_threads, basedir, f"{arg_prefix}bench_cbag_prod_cons_{n}", now))
            if use_1_cons:
                # benchmark prod/cons where 1 of the x threads is a consumer
                benchmarks.append(Benchmark(binary.bench_cbag_prod_cons, (n, lambda x: 1), repetitions, num_threads, basedir, f"{arg_prefix}bench_cbag_1_cons_{n}", now))
            if use_1_prod:
                # benchmark prod/cons where x - 1 of the x threads are consumers (1 producer)
                benchmarks.append(Benchmark(binary.bench_cbag_prod_cons, (n, lambda x: x - 1), repetitions, num_threads, basedir, f"{arg_prefix}bench_cbag_1_prod_{n}", now))
            if use_add_tra:
                # benchmark add and try_remove_any times
                benchmarks.append(Benchmark(binary.bench_cbag_add_try_remove_any, (n,), repetitions, num_threads, basedir, f"{arg_prefix}bench_cbag_add_tra_{n}", now))

    for benchmark in benchmarks:
        benchmark.run()
        benchmark.write_avg_data()
        print(f'successfully ran benchmark {benchmark}')
    for benchmark in benchmarks:
        print(benchmark)

if __name__ == "__main__":
    # convert each pair of extra command-line arguments ('--key value') into a key/value entry in the parameters dict
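    # Example invocation (hypothetical script name and values):
    #   python3 bench.py --workload 1000,10000 --threads 2,4 --repetitions 5
    # yields parameters == {'workload': '1000,10000', 'threads': '2,4', 'repetitions': '5'}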
    additional_args = sys.argv[1:]
    parameters = {key.strip('-'): value for key, value in zip(additional_args[::2], additional_args[1::2])}
    benchmark(parameters)
